#include "alt.hpp"
#include "probs.hpp"
#include "func.hpp"

#include <cmath>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <iostream>
#include <limits>
#include <map>
#include <new>
#include <sstream>
#include <string>
#include <vector>

#include <omp.h>
#include <gsl/gsl_errno.h>

// Gaussian density with standard deviation sd, evaluated at a point whose
// standardised value is x. Callers in this file pass x=(tau-mean)/sd, so the
// result is exp(-x^2/2) / (sd*sqrt(2*pi)).
double norm(double x,double sd)
{
  const double kernel=std::exp(-x*x/2);
  const double scale=sd*std::sqrt(2*M_PI);
  return kernel/scale;
}

double func(double tau, void*d)
{
  details* dd=(details*)d;
  
  int n=dd->tt.length();
  double*ps;
  try{  ps= new double[(n)*(1+ustate::maxel)];}
  catch(std::bad_alloc&x) { std::cerr<<"BA\n";} 
  

//  std::cerr<<omp_get_thread_num()<<"new"<<ps<<"\n";

  double dnorm=norm((tau-dd->npars[0])/dd->npars[1],dd->npars[1]);
  
  double temp=dd->totaltime-tau;
  
  probs(ps,&temp,dd->pars,&n);
  double r= dnorm*dd->tt.pcor(dd->ff,ps);
  
  delete[] ps;
//  std::cerr<<omp_get_thread_num()<<"del"<<ps<<"\n";

  return r;
}

/*
 * Integrates func() over the whole real line for the details in d.
 * The range is split at d->pars[0]:
 *   - (-inf, pars[0]] with QAGIL,
 *   - [pars[0], +inf) with QAGIU.
 * Each half uses absolute tolerance 1e-4 and at most 400 subintervals.
 *
 * Returns the sum of both halves, or NaN if the GSL workspace cannot be
 * allocated.
 * A non-zero GSL status from either half is reported on std::cerr: thread id,
 * both status codes, d->totaltime, d->pars and both partial results. The
 * possibly inaccurate sum is still returned.
 */
double totalprob(details*d)
{
  // Turn GSL's default (aborting) error handler off before allocating, so
  // failures come back as status codes or null pointers. The setting is
  // process-global.
  gsl_set_error_handler_off();

  gsl_function f;
  f.function=func;
  f.params=static_cast<void*>(d);

  gsl_integration_workspace *ws=gsl_integration_workspace_alloc(400);
  if(ws==0)
  {
    std::cerr<<"totalprob: gsl_integration_workspace_alloc failed\n";
    return std::numeric_limits<double>::quiet_NaN();
  }

  double lower=0,upper=0,e=0;
  // Check both status codes separately so a failure on either half is reported.
  const int errl=gsl_integration_qagil(&f,d->pars[0],1e-4,0,400,ws,&lower,&e);
  const int erru=gsl_integration_qagiu(&f,d->pars[0],1e-4,0,400,ws,&upper,&e);
  gsl_integration_workspace_free(ws);

  if(errl!=0||erru!=0)
  {
    std::cerr<<omp_get_thread_num()<<"@"<<errl<<"/"<<erru<<"\t"<<d->totaltime;
    for(std::size_t i=0;i<d->pars.size();++i) std::cerr<<"\t"<<d->pars[i];
    std::cerr<<"\t"<<lower<<"\t"<<upper<<"\n";
  }

  return lower+upper;
}

double allfunc(double tau, void*d)
{
  details* dd=(details*)d;
  
  int n=dd->tt.length();
  double*ps;
  int tel=n*3;
  try{  ps= new double[(n)*(1+ustate::maxel)];}
  catch(std::bad_alloc&x) { std::cerr<<"BA\n";} 
  

//  std::cerr<<omp_get_thread_num()<<"new"<<ps<<"\n";

  double dnorm=norm((tau-dd->npars[0])/dd->npars[1],dd->npars[1]);
  
  double temp=dd->totaltime-tau;
  
  probs(ps,&temp,dd->pars,&n);
  double r=dnorm;
  
  
  for(int i=2;i<tel;i+=3) r*=ps[i];
//  for(int i=0;i<tel;++i) std::cerr<<ps[i]<<" ";
//  std::cerr<<"\n";
  delete[] ps;
//  std::cerr<<omp_get_thread_num()<<"del"<<ps<<"\n";

  return r;
}

/*
 * Integrates allfunc() over the whole real line for the details in d.
 * The range is split at d->pars[0]:
 *   - (-inf, pars[0]] with QAGIL,
 *   - [pars[0], +inf) with QAGIU.
 * Each half uses absolute tolerance 1e-4 and at most 400 subintervals.
 *
 * Returns the sum of both halves, or NaN if the GSL workspace cannot be
 * allocated.
 * A non-zero GSL status from either half is reported on std::cerr: thread id,
 * both status codes, d->totaltime, d->pars and both partial results. The
 * possibly inaccurate sum is still returned.
 */
double allprob(details*d)
{
  // Turn GSL's default (aborting) error handler off before allocating, so
  // failures come back as status codes or null pointers. The setting is
  // process-global.
  gsl_set_error_handler_off();

  gsl_function f;
  f.function=allfunc;
  f.params=static_cast<void*>(d);

  gsl_integration_workspace *ws=gsl_integration_workspace_alloc(400);
  if(ws==0)
  {
    std::cerr<<"allprob: gsl_integration_workspace_alloc failed\n";
    return std::numeric_limits<double>::quiet_NaN();
  }

  double lower=0,upper=0,e=0;
  // Check both status codes separately so a failure on either half is reported.
  const int errl=gsl_integration_qagil(&f,d->pars[0],1e-4,0,400,ws,&lower,&e);
  const int erru=gsl_integration_qagiu(&f,d->pars[0],1e-4,0,400,ws,&upper,&e);
  gsl_integration_workspace_free(ws);

  if(errl!=0||erru!=0)
  {
    std::cerr<<omp_get_thread_num()<<"@"<<errl<<"/"<<erru<<"\t"<<d->totaltime;
    for(std::size_t i=0;i<d->pars.size();++i) std::cerr<<"\t"<<d->pars[i];
    std::cerr<<"\t"<<lower<<"\t"<<upper<<"\n";
  }

  return lower+upper;
}

/*
 * C entry point. For each i in [0, *n), writes to pk[i]:
 *   - allprob(det)   when target[i] and foil[i] are identical strings,
 *   - totalprob(det) otherwise.
 * det is built from target[i], foil[i], time[i] and allpars.
 *
 * allpars : allpars[0..1] become details::npars, and
 *           allpars[2 .. 4+strlen(target[i])) become details::pars.
 *           It must therefore hold at least 4+strlen(target[i]) values for
 *           every i.
 *
 * Results are memoised within one call, keyed by "target/foil/time".
 */
extern "C" void pk_vector(double *pk,double *allpars,const char**target,const char**foil,double * time,int *n)
{
  gsl_set_error_handler_off();

  std::map<std::string,double> cache;

  for(int i=0;i<*n;++i)
  {
    // Write the time with full round-trip precision. The default precision of
    // 6 significant digits would let distinct times share one cache entry.
    std::stringstream ss;
    ss.precision(std::numeric_limits<double>::max_digits10);
    ss<<target[i]<<"/"<<foil[i]<<"/"<<time[i];
    const std::string spec=ss.str();

    std::map<std::string,double>::const_iterator hit=cache.find(spec);
    if(hit!=cache.end())
    {
      pk[i]=hit->second;
      continue;
    }

    // Cache miss: only now build the details for this entry.
    details det;
    det.pars.assign(allpars+2,allpars+4+std::strlen(target[i]));
    det.npars.assign(allpars,allpars+2);
    det.tt=target[i];
    det.ff=foil[i];
    det.totaltime=time[i];

    const double value=(std::strcmp(target[i],foil[i])==0)?allprob(&det):totalprob(&det);
    cache[spec]=value;
    pk[i]=value;
  }
}

double sfunc(double tau,void*d)
{
  details dd=*(details*)d;
//  double temp1=dd->totaltime,  temp2=dd->pars[1];
  dd.totaltime=tau;
  dd.pars[1]=0.;
  double r=(tau>0?1:0)-totalprob(&dd);
//  std::cerr<<"r"<<r<<" "<<tau<<"\t";
//  dd->totaltime=temp1;
//  dd->pars[1]=temp2;
  return r;
}

/*
 * Integrates sfunc() for the details in d in three steps:
 *   - over (-inf, 0] with QAGIL (abs. tol 1e-2), result in r2;
 *   - over [0, d->totaltime] with QAG / Gauss-Kronrod 61 (abs. tol 1e-3),
 *     result in r;
 *   - adds (1 - totalprob(d with pars[1]=0)) * d->pars[1] to r.
 *
 * Returns r + r2, or NaN if the GSL workspace cannot be allocated.
 * A non-zero GSL status is reported on std::cerr together with d->npars and
 * d->pars.
 */
double totalprime(details*d)
{
  gsl_set_error_handler_off();

  gsl_function f;
  f.function=sfunc;
  f.params=static_cast<void*>(d);

  // Print the status and every parameter. Iterate over the real sizes: pars
  // holds 2+strlen(prime) values, so a fixed count of 6 could read past its end.
  auto report=[d](int status)
  {
    std::cerr<<status;
    for(std::size_t j=0;j<d->npars.size();++j) std::cerr<<"\t"<<d->npars[j];
    for(std::size_t i=0;i<d->pars.size();++i) std::cerr<<"\t"<<d->pars[i];
    std::cerr<<"\n";
  };

  gsl_integration_workspace *ws=gsl_integration_workspace_alloc(600);
  if(ws==0)
  {
    std::cerr<<"totalprime: gsl_integration_workspace_alloc failed\n";
    return std::numeric_limits<double>::quiet_NaN();
  }

  double r=0,r2=0,e=0;

  int errn=gsl_integration_qagil(&f,0,1e-2,0,600,ws,&r2,&e);
  if(errn!=0) report(errn);

  errn=gsl_integration_qag(&f,0,d->totaltime,1e-3,0,600,GSL_INTEG_GAUSS61,ws,&r,&e);
  if(errn!=0) report(errn);

  gsl_integration_workspace_free(ws);

  // Correction term. It uses a copy with pars[1] zeroed, scaled by the
  // original pars[1]; *d itself is left unchanged.
  details dd=*d;
  const double temp=dd.pars[1];
  dd.pars[1]=0;
  r+=(1-totalprob(&dd))*temp;

  return r+r2;
}

// Guards the result cache shared between the OpenMP threads of
// prime_vector_ml2(), which initialises and destroys it.
static omp_lock_t lock;

/*
 * Computes totalprime() for one (prime, target, time) job and memoises it.
 *
 * allpars : per-job parameters. allpars[0..1] become details::npars and
 *           allpars[2 .. 4+strlen(*prime)) become details::pars, so the
 *           caller must supply at least 4+strlen(*prime) values.
 * prime, target, time : pointers to this job's inputs (only read).
 * cache   : results keyed by "prime/target/time", accessed only while holding
 *           `lock`. If null, the value is computed without caching.
 *
 * Returns the cached or freshly computed totalprime() value.
 *
 * NOTE: *time is written with the stream's default precision (6 significant
 * digits), so jobs whose times agree to that precision share a cache entry.
 */
inline double prime_vector_th2(double*allpars,char**prime,char**target,double* time,std::map<std::string,double> *cache)
{
  // Holds `lock` for its lifetime, so the lock is released even if a map
  // operation throws.
  struct lock_holder
  {
    lock_holder(){ omp_set_lock(&lock); }
    ~lock_holder(){ omp_unset_lock(&lock); }
  };

  std::stringstream ss;
  ss<<*prime<<"/"<<*target<<"/"<<*time;
  const std::string spec=ss.str();

  if(cache!=0)
  {
    lock_holder held;
    std::map<std::string,double>::const_iterator j=cache->find(spec);
    if(j!=cache->end()) return j->second;
  }

  // Cache miss (or no cache): build the details for this job and integrate.
  // Two threads may compute the same key concurrently. Both see identical
  // inputs, and the later write simply overwrites the earlier one.
  details det;
  det.pars.assign(allpars+2,allpars+4+std::strlen(*prime));
  det.npars.assign(allpars,allpars+2);
  det.tt=*prime;
  det.ff=*target;
  det.totaltime=*time;

  const double r=totalprime(&det);

  if(cache!=0)
  {
    lock_holder held;
    (*cache)[spec]=r;
  }
  return r;
}

// Terminate handler installed by prime_vector_ml2(). It reports the OpenMP
// thread on which std::terminate was reached, then aborts, because a
// terminate handler must not return to its caller.
void urgh(){ std::cerr<<"urgh"<<omp_get_thread_num()<<"\n"; std::abort(); }

/*
 * C entry point. For each i in [0, *n):
 *   pk[i] = prime_vector_th2(prime[i], target[i], time[i])
 *         - prime_vector_th2(control[i], target[i], time[i]).
 * The 2*n jobs run in an OpenMP parallel loop and share one memoisation
 * cache, guarded by `lock`.
 *
 * allpars : model parameters; at least 6 values are read. Each job gets its
 * own parameter vector:
 *   allpars[0..3], then ratio*(1+allpars[4]), then len-1 copies of ratio,
 * where len = strlen(that job's prime string) and
 *       ratio = allpars[5]/(len+allpars[4]).
 *
 * urgh() is the terminate handler for the duration of the call; the previous
 * handler is restored before returning.
 */
extern "C" void prime_vector_ml2(double*pk,double*allpars,const char**prime,const char**target,const char**control,double* time, int*n)
{
  const std::terminate_handler previous_handler=std::set_terminate(urgh);

  omp_init_lock(&lock);

  gsl_set_error_handler_off();
  std::map<std::string,double> cache;

  const int count=(*n>0)?*n:0;
  const int jobs=2*count;

  // Job i (< count) pairs prime[i] with target[i].
  // Job count+i pairs control[i] with target[i].
  // Both use time[i]. prime_vector_th2() takes char** but only reads the
  // strings, so dropping const here is safe.
  std::vector<char*> primes(jobs),targets(jobs);
  std::vector<double> times(jobs),results(jobs);
  for(int i=0;i<count;++i)
  {
    primes[i]=const_cast<char*>(prime[i]);
    primes[count+i]=const_cast<char*>(control[i]);
    targets[i]=const_cast<char*>(target[i]);
    targets[count+i]=const_cast<char*>(target[i]);
    times[i]=time[i];
    times[count+i]=time[i];
  }

  #pragma omp parallel for schedule(dynamic)
  for(int i=0;i<jobs;++i)
  {
    const int len=static_cast<int>(std::strlen(primes[i]));
    const double ratio=allpars[5]/(len+allpars[4]);

    std::vector<double> parbylen;
    parbylen.reserve(5+len);
    parbylen.push_back(allpars[0]);
    parbylen.push_back(allpars[1]);
    parbylen.push_back(allpars[2]);
    parbylen.push_back(allpars[3]);
    parbylen.push_back(ratio*(1+allpars[4]));
    for(int j=1;j<len;++j) parbylen.push_back(ratio);

    // Each iteration writes only its own slot of `results`.
    results[i]=prime_vector_th2(parbylen.data(),&primes[i],&targets[i],&times[i],&cache);
  }

  // Use the value returned for each job directly, instead of rebuilding cache
  // keys (and default-inserting on a miss).
  for(int i=0;i<count;++i) pk[i]=results[i]-results[count+i];

  omp_destroy_lock(&lock);
  std::set_terminate(previous_handler);
}
